from typing import Optional, Union
import numpy as np

def rti(G: np.ndarray, k_max: Optional[int] = None) -> np.ndarray:
    """
    Refined Triangle Inequality (RTI) bound.

    Returns upper bounds on k ↦ max_{T ⊆ [n], |T| = k} ∑_{i,j ∈ T} G[i,j]
    for k = 1, ..., k_max.

    The bound is computed in two passes:
      1. For every column j, M1[k-1, j] is the sum of the k largest entries of
         column j (an upper bound on ∑_{i ∈ T} G[i, j] for any |T| = k).
      2. For every k, the bound is the sum of the k largest values in row k-1
         of M1 (an upper bound on ∑_{j ∈ T} M1[k-1, j]).

    Args:
        G: (n, n) input matrix (e.g., Gram matrix). Any real dtype; it is
            converted to float64 without modifying the caller's array.
        k_max: Optional[int], if specified, compute bounds only for k = 1..k_max.
            Must satisfy 0 <= k_max <= n. Defaults to n.

    Returns:
        A writable float64 array of length k_max (or n if k_max is None);
        entry k-1 is the bound for subsets of size k.

    Raises:
        ValueError: if G is not a square 2D array or k_max is outside [0, n].
    """
    # Read-only use below, so no explicit copy is needed.
    M = np.asarray(G, dtype=np.float64)
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError(f"G must be a square 2D array, got shape {M.shape}")
    n = M.shape[0]

    if k_max is None:
        k_max = n
    # Without this check a negative k_max silently produced a wrong-length
    # result and k_max > n failed deep inside argpartition.
    if not 0 <= k_max <= n:
        raise ValueError(f"k_max must satisfy 0 <= k_max <= {n}, got {k_max}")
    if k_max == 0:
        return np.zeros(0)

    # Step 1: per column, the k_max largest entries (partition is O(n) per
    # column; only the selected entries are then sorted), in descending order.
    col_idx = np.argpartition(-M, kth=k_max - 1, axis=0)[:k_max, :]
    col_top = -np.sort(-np.take_along_axis(M, col_idx, axis=0), axis=0)
    M1 = np.cumsum(col_top, axis=0)  # (k_max, n); M1[k-1, j] = top-k sum of column j

    # Step 2: per row of M1, the k_max largest column bounds, descending.
    row_idx = np.argpartition(-M1, kth=k_max - 1, axis=1)[:, :k_max]
    row_top = -np.sort(-np.take_along_axis(M1, row_idx, axis=1), axis=1)
    M2 = np.cumsum(row_top, axis=1)  # (k_max, k_max)

    # Only the diagonal (row k-1 summed over its top k columns) is the bound.
    # Copy so callers get a writable array rather than a read-only view.
    return M2.diagonal().copy()





def fractional_knapsack_rows(G: np.ndarray, w: np.ndarray, capacities: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Solve a fractional (LP-relaxed) knapsack independently for every row of G.

    Row i is a set of items with values G[i, :] and weights w. For each
    capacity c, items are taken greedily in decreasing value/weight order:
    fully while they fit, then a fraction of the next item to fill the
    remaining capacity. Items with negative value/weight ratio are never taken.
    The fractional optimum upper-bounds the best 0/1 selection of weight <= c.

    The sort and gather steps are vectorized; the capacity lookup loops over
    rows (one `np.searchsorted` call per row).

    Args:
        G: (r, n) matrix of item values; each row is an independent instance
            (square in typical use).
        w: (n,) vector of non-negative item weights shared by all rows.
        capacities: (m,) array of real-valued capacities to evaluate.
            Defaults to [1, 2, ..., r] where r = G.shape[0].

    Returns:
        F: (r, m) float array where F[i, j] is the optimal fractional value for
        row i at capacity capacities[j]. Negative capacities yield 0.

    Raises:
        ValueError: if w's length does not match G's columns, or w has
            negative entries.
        AssertionError: if capacities is not 1-D.
    """
    G = np.asarray(G)
    w = np.asarray(w, dtype=np.float64)
    n_rows = G.shape[0]
    if G.ndim != 2 or w.ndim != 1 or w.shape[0] != G.shape[1]:
        raise ValueError(f"expected G of shape (r, n) and w of shape (n,), got {G.shape} and {w.shape}")
    # Cumulative weights must be non-decreasing for searchsorted to be valid.
    if np.any(w < 0):
        raise ValueError("weights must be non-negative")

    if capacities is None:
        capacities = np.arange(1, n_rows + 1)
    else:
        capacities = np.asarray(capacities)
        assert capacities.ndim == 1, "capacities must be a 1D array"
    capacities = capacities.astype(np.float64)

    m = capacities.shape[0]
    if m == 0:
        return np.zeros((n_rows, 0))

    # Sentinel item: zero value and a (positive) weight exceeding every
    # capacity even after all real items, so it is never taken in full. It
    # caps the number of fully-taken items at n and, having ratio 0, stops the
    # greedy scan before any negative-ratio item.
    dummy_weight = max(float(np.max(capacities)), 0.0) + float(np.sum(w)) + 1.0
    G_aug = np.concatenate([G, np.zeros((n_rows, 1))], axis=1)  # (r, n+1)
    w_aug = np.append(w, dummy_weight)                           # (n+1,)

    # 0/0 (zero value, zero weight) gives NaN, which argsort places last.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = G_aug / w_aug[None, :]
    order = np.argsort(-ratio, axis=1)  # descending ratio per row
    G_sorted = np.take_along_axis(G_aug, order, axis=1)
    w_sorted = w_aug[order]

    w_cumsum = np.cumsum(w_sorted, axis=1)
    # Prefix sums with a leading 0 so index k means "first k items taken".
    zero_col = np.zeros((n_rows, 1))
    G_prefix = np.concatenate([zero_col, np.cumsum(G_sorted, axis=1)], axis=1)
    w_prefix = np.concatenate([zero_col, w_cumsum], axis=1)

    # counts[i, j] = number of items of row i that fit entirely in capacity j.
    counts = np.empty((n_rows, m), dtype=np.intp)
    for i, row in enumerate(w_cumsum):
        counts[i] = np.searchsorted(row, capacities, side="right")

    full_vals = np.take_along_axis(G_prefix, counts, axis=1)
    # Weight of the items actually taken (excluding the next one). The previous
    # version read w_cumsum[counts], which already included the next item, so
    # the remaining capacity — and the fractional term — was always zero.
    taken_weight = np.take_along_axis(w_prefix, counts, axis=1)
    remaining = np.maximum(capacities[None, :] - taken_weight, 0.0)

    # counts <= n always (sentinel), so index counts is a valid "next item".
    next_w = np.take_along_axis(w_sorted, counts, axis=1)
    next_v = np.take_along_axis(G_sorted, counts, axis=1)
    next_frac = np.divide(remaining, next_w, out=np.zeros_like(remaining), where=next_w > 0)

    return full_vals + next_v * next_frac



def weighted_rti(G: np.ndarray, w: np.ndarray, capacities: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Weighted Refined Triangle Inequality bound.

    Evaluates WRTI(G) = diag(FK(FK(G).T)), where FK is the row-wise fractional
    knapsack ``fractional_knapsack_rows`` evaluated at ``capacities``.

    Args:
        G: (n, n) symmetric matrix.
        w: (n,) weight vector, used for both knapsack passes.
        capacities: (m,) vector of capacities; None uses the default of
            ``fractional_knapsack_rows``.

    Returns:
        Array of length m; entry j is the bound for capacity ``capacities[j]``.
    """
    # Pass 1: for every row of G and every capacity, a knapsack over columns.
    per_row = fractional_knapsack_rows(G, w, capacities)
    # Pass 2: for every capacity (now a row), a knapsack over the original rows.
    per_capacity = fractional_knapsack_rows(per_row.T, w, capacities)
    # Only matching capacity pairs (c, c) form the bound.
    return per_capacity.diagonal()



from typing import Optional
import numpy as np

def sparse_pca_inv(G: np.ndarray, weights: Optional[np.ndarray] = None, capacities: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Upper-bound the maximum singular value of (I - G[T, T])^{-1} over all index
    sets T whose total weight (or size) is at most a given capacity.

    With Gamma = diag(G), R = diag(1 / (1 - Gamma)), G_hat = G - diag(Gamma) and
    A = R @ G_hat, we have I - G[T, T] = R[T, T]^{-1} (I - A[T, T]), hence
    ||(I - G[T, T])^{-1}|| <= ||R|| / (1 - ||A[T, T]||) whenever ||A[T, T]|| < 1.
    ||A[T, T]|| is bounded by its Frobenius norm, whose square is bounded by
    (weighted) RTI applied to |A|^2 and to its transpose; the tighter of the two
    is used.

    Args:
        G: (n, n) Gram matrix (assumed symmetric) with diagonal entries < 1.
        weights: Optional (n,) vector of weights for each coordinate. If None,
            unweighted RTI is used and a capacity c admits sets of size
            <= floor(c) (capped at n).
        capacities: Optional (m,) array of capacities to evaluate the bounds at.
            If None, defaults to [1, 2, ..., n].

    Returns:
        s: (m,) float vector where s[j] bounds capacity capacities[j];
        np.inf where the bound is not finite (Frobenius bound >= 1).

    Raises:
        ValueError: if G is not square or has a diagonal entry >= 1.
    """
    G = np.asarray(G)
    if G.ndim != 2 or G.shape[0] != G.shape[1]:
        raise ValueError(f"G must be a square 2D array, got shape {G.shape}")
    n = G.shape[0]

    gamma = np.diag(G)
    if np.any(1 - gamma <= 0):
        raise ValueError("Invalid G: all diagonal entries must be strictly less than 1.")

    # A = diag(1 / (1 - Gamma)) @ (G - diag(Gamma)), done as a row scaling
    # instead of an O(n^3) dense matrix product.
    r_diag = 1 / (1 - gamma)
    A = r_diag[:, None] * (G - np.diag(gamma))
    A_sq = np.square(np.abs(A))

    if capacities is None:
        capacities = np.arange(1, n + 1)
    capacities = np.asarray(capacities)

    if weights is None:
        # rti() takes a maximum subset size, not capacities (the previous call
        # passed capacities= and raised TypeError). Evaluate it once up to the
        # largest size needed and look up each capacity. Because A_sq >= 0
        # entrywise, the size-k bound also covers every |T| < k; index 0 is
        # the empty set with bound 0.
        sizes = np.clip(np.floor(capacities), 0, n).astype(np.intp)
        k_max = int(sizes.max(initial=0))
        if k_max > 0:
            per_size1 = np.concatenate(([0.0], rti(A_sq, k_max=k_max)))
            per_size2 = np.concatenate(([0.0], rti(A_sq.T, k_max=k_max)))
        else:
            per_size1 = per_size2 = np.zeros(1)
        bound1 = per_size1[sizes]
        bound2 = per_size2[sizes]
    else:
        bound1 = weighted_rti(A_sq, w=weights, capacities=capacities)
        bound2 = weighted_rti(A_sq.T, w=weights, capacities=capacities)

    # f[j] bounds ||A[T, T]||_F for every admissible T at capacity j.
    f = np.minimum(np.sqrt(bound1), np.sqrt(bound2))

    # Largest singular value of the diagonal matrix R.
    sigma_R = 1 / (1 - np.max(gamma))

    # Divide only where the Neumann-series bound applies (f < 1); elsewhere
    # the bound is infinite. This avoids divide-by-zero warnings at f == 1.
    s = np.full(f.shape, np.inf)
    finite = f < 1
    s[finite] = sigma_R / (1 - f[finite])
    return s
